{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo - Applications of MultiAttack (CIFAR10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "import torchvision.utils\n",
    "from torchvision import models\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import torchattacks\n",
    "\n",
    "from models import Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "\n",
    "cifar10_test  = dsets.CIFAR10(root='./data', train=False,\n",
    "                              download=True, transform=transforms.ToTensor())\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(cifar10_test,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Target().cuda()\n",
    "model.load_state_dict(torch.load(\"./checkpoints/target.pth\"))\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. PGD Random Restart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pgd = torchattacks.PGD(model, eps=4/255, alpha=2/255, steps=4, random_start=True)\n",
    "atk = torchattacks.MultiAttack([pgd]*10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 1.48 % / L2: 0.76154\r"
     ]
    }
   ],
   "source": [
    "# Number of Random Restart = 1\n",
    "pgd.save(data_loader=test_loader, save_path=None, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 1.20 % / L2: 0.76150\r"
     ]
    }
   ],
   "source": [
    "# Number of Random Restart = 10\n",
    "atk.save(data_loader=test_loader, save_path=None, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. CW c-search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_candidate = [1, 0.1, 0.01, 0.001]\n",
    "atk = torchattacks.MultiAttack([torchattacks.CW(model, c, kappa=0, steps=1000, lr=0.01) for c in c_candidate])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 0.00 % / L2: 0.16731\r"
     ]
    }
   ],
   "source": [
    "# Find the best c in c_candidate\n",
    "atk.save(data_loader=test_loader, save_path=None, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Attack on Only Correct Examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the accuracy under an adversarial attack applied only to the correctly classified examples is usually the same as the accuracy under the same adversarial attack applied to the whole set of examples, except in the following cases:\n",
    "\n",
    "* Using random start (demonstrated below)\n",
    "* Catastrophic overfitting (https://arxiv.org/abs/2010.01799)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchattacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "vanila = torchattacks.VANILA(model)\n",
    "pgd = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4, random_start=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "atk = torchattacks.MultiAttack([vanila, pgd])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 0.18 % / L2: 1.25548\r"
     ]
    }
   ],
   "source": [
    "# Robust Accuracy on Whole Examples (PGD applied to every test image)\n",
    "pgd.save(data_loader=test_loader, save_path=None, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 0.20 % / L2: 1.09639\r"
     ]
    }
   ],
   "source": [
    "# Robust Accuracy on only Correct Examples\n",
    "atk.save(data_loader=test_loader, save_path=None, verbose=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
